#coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import tensorflow as tf
import numpy as np
import os
from ml_feature.util.Operation import JsonOperation
from ml_feature.util.Operation import load_jsonfile, load_pbfile
from ml_feature.util.Transfer import load_pbtxt_file, transfer_format, transfer_o_pbtxt
from ml_feature.util.QtDialog import PlaceholderDialog
from ml_feature.util import get_config_view_hidden_ops, get_config_view_show_features, get_config_json, bytes_to_human, get_config
from PyQt5.QtCore import QThread
from google.protobuf import text_format
import time
import platform

# Keep TensorFlow's own log output quiet while models are loaded and analysed:
# only messages at ERROR level or above are emitted.
tf.logging.set_verbosity(tf.logging.ERROR)


class TimeThread(QThread):
    """
    Background thread that keeps reporting elapsed time to a progress callback.

    Opening an .mlc file takes a long time; showing the elapsed seconds on the
    progress bar tells the user the application is still working, so they do
    not interact with it by mistake.

    The thread runs until it is asked to stop. Prefer ``requestInterruption()``
    followed by ``wait()``; ``terminate()`` still works, but Qt discourages it
    because the thread may be killed at an arbitrary point.
    """

    def __init__(self, emit, emit_value):
        """
        :param emit: progress callback, called as ``emit(message, emit_value)``;
            when falsy, nothing is reported.
        :param emit_value: progress value passed along with every message.
        """
        super(TimeThread, self).__init__()
        self.emit = emit
        self.emit_value = emit_value
        # Elapsed time is measured from construction, not from start().
        self.start_time = time.time()

    def run(self):
        while not self.isInterruptionRequested():
            if self.emit:
                self.emit('Merge intermediate file to graph_def %d' % (int(time.time() - self.start_time)), self.emit_value)
            # Sleep ~1 s in short slices so a stop request is noticed promptly.
            for _ in range(10):
                if self.isInterruptionRequested():
                    return
                time.sleep(0.1)


def strip_consts(graph_def, max_const_size=32):
    """
    Return a copy of *graph_def* with large constant payloads replaced by a marker.

    Every node is copied into a new GraphDef. For ``Const`` nodes whose ``value``
    tensor has a ``tensor_content`` longer than *max_const_size* bytes, that
    content is replaced with the UTF-8 bytes of ``"<stripped N bytes>"``
    (N = the original length), so the graph can be displayed without embedding
    the raw data. The input *graph_def* is only read, never modified.
    """
    stripped = tf.GraphDef()
    for source_node in graph_def.node:
        copied = stripped.node.add()
        copied.MergeFrom(source_node)
        if copied.op != 'Const':
            continue
        const_tensor = copied.attr['value'].tensor
        payload_len = len(const_tensor.tensor_content)
        if payload_len <= max_const_size:
            continue
        const_tensor.tensor_content = ("<stripped %d bytes>" % payload_len).encode('utf-8')
    return stripped


def rename_nodes(graph_def, rename_func):
    """
    Return a copy of *graph_def* with node names rewritten by *rename_func*.

    *rename_func* (str -> str) is applied to each node's name and to every entry
    of its ``input`` list. Control-dependency inputs, written as ``"^name"``,
    keep their leading ``^``; only the part after it is renamed. The input
    *graph_def* is only read, never modified.
    """
    def _rename_input(ref):
        # '^' marks a control dependency; preserve the marker, rename the target.
        if ref[0] == '^':
            return '^' + rename_func(ref[1:])
        return rename_func(ref)

    renamed_def = tf.GraphDef()
    for source_node in graph_def.node:
        copied = renamed_def.node.add()
        copied.MergeFrom(source_node)
        copied.name = rename_func(copied.name)
        new_inputs = [_rename_input(ref) for ref in copied.input]
        for position, ref in enumerate(new_inputs):
            copied.input[position] = ref
    return renamed_def


def parse_graph_to_html(graph_def, max_const_size=32, base_url="/files/tensorboard/", raw=False):
    """
    Render a graph as an HTML snippet for TensorBoard's ``tf-graph-basic`` widget.

    :param graph_def: GraphDef to render, or an object with ``as_graph_def()``
        (e.g. a tf.Graph), which is converted first.
    :param max_const_size: Const nodes whose ``tensor_content`` is longer than
        this many bytes get it replaced by a short marker (see strip_consts),
        which keeps the embedded text small.
    :param base_url: URL prefix of ``tf-graph-basic.build.html``; only used when
        *raw* is False.
    :param raw: when it is the object ``False`` (checked with ``is``), the snippet
        loads the widget through ``<link rel="import">`` and sets the graph in its
        onload handler; otherwise it assumes the widget is already available and
        sets the graph in an inline script.
    :return: the HTML snippet (str).
    """
    import json

    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)

    # Serialize the text-format graph as a JavaScript string literal.
    # json.dumps produces valid JS (repr() may emit Python-only escapes such as
    # \U0001f600). Escaping <, > and & as well keeps text taken from the model
    # file (e.g. a string constant containing "</script>") from closing the
    # surrounding <script> element and injecting markup.
    data = (json.dumps(str(strip_def))
            .replace('<', '\\u003c')
            .replace('>', '\\u003e')
            .replace('&', '\\u0026'))
    graph_id = 'graph' + str(int(np.random.rand() * 1000))

    if raw is False:
        code = """
            <script>
              function load() {{
                  document.getElementById("{id}").pbtxt = {data};
              }};
            </script>
            <div style="height:600px">
              <tf-graph-basic id="{id}"></tf-graph-basic>
            </div>
            <link rel="import" href="{base_url}tf-graph-basic.build.html" onload="load()">
        """.format(data=data, id=graph_id, base_url=base_url)
    else:
        code = """
            <div style="height:600px">
              <tf-graph-basic id="{id}"></tf-graph-basic>
            </div>
            <script>
              document.getElementById("{id}").pbtxt = {data};
            </script>
        """.format(data=data, id=graph_id)
    return code


def load_pb_graph(model_file, emit=None, emit_value=None):
    """
    Load a model file into a GraphDef.

    The format is chosen from the file extension (compared case-insensitively):
    ``.pb`` is read with load_pbfile, ``.pbtxt`` with load_pbtxt_file, and
    ``.mlc`` is converted to pbtxt text by transfer_o_pbtxt and then parsed with
    ``text_format.Merge``.

    :param model_file: path of the model file to load
    :param emit: optional progress callback, called as ``emit(message, value)``
    :param emit_value: progress-bar value on entry; progress is reported as
        offsets (+1..+3) from it. ``None`` is treated as 0.
    :return: ``(graph_def, placeholder_valuedict)``; the dict is empty for
        ``.pb``/``.pbtxt`` and is the one returned by transfer_o_pbtxt for ``.mlc``.
    :raises ValueError: if the extension is not ``.pb``, ``.pbtxt`` or ``.mlc``.
    """
    base_value = emit_value if emit_value is not None else 0
    model_file_suffix = os.path.splitext(model_file)[1].lower()
    if model_file_suffix == '.pb':
        if emit:
            emit('load pb file', base_value + 2)
        graph_def = load_pbfile(model_file, flag=True)
        placeholder_valuedict = dict()
    elif model_file_suffix == '.pbtxt':
        if emit:
            emit('load pbtxt file', base_value + 2)
        graph_def = load_pbtxt_file(model_file, flag=True)
        placeholder_valuedict = dict()
    elif model_file_suffix == '.mlc':
        if emit:
            emit('load mlc file', base_value + 1)
        content, placeholder_valuedict = transfer_o_pbtxt(model_file, emit=emit, emit_value=base_value + 1)
        graph_def = tf.GraphDef()
        timer = None
        if emit:
            emit('Merge intermediate file to graph_def', base_value + 2)
            # Parsing a large pbtxt can take a long time; keep reporting the
            # elapsed seconds so the UI does not look frozen.
            timer = TimeThread(emit, base_value + 2)
            timer.start()
        try:
            text_format.Merge(content, graph_def)
        finally:
            # Stop the ticker even when parsing fails, and wait for it to
            # finish so it cannot keep emitting after this function returns.
            if timer is not None:
                timer.terminate()
                timer.wait()
    else:
        raise ValueError('unsupported model file type %r: %s (expected .pb, .pbtxt or .mlc)'
                         % (model_file_suffix, model_file))
    if emit:
        emit('graph_def import', base_value + 3)
    return graph_def, placeholder_valuedict


def drawtable_optype(result, feature_dict, op_ban_list):
    '''
    Build an HTML table with one row per op type, followed by a "Summary" row.

    :param result: dict mapping op type -> dict of feature values for that type.
        A falsy value (None or an empty dict) means the op type has no feature
        definition in the config; such op types are listed after the others,
        in the order they appear in *result*, with '-' in every feature column.
    :param feature_dict: ordered mapping of feature key -> column spec; each spec
        provides 'col_name' (header text) and 'unit_int' / 'unit_str' (passed to
        bytes_to_human as ``factor`` / ``unint_str``). Column order follows its keys.
    :param op_ban_list: op types that must not appear in the table.
    :return: the table HTML (str). For an empty *result* only the empty
        ``<table>`` element is returned, without a header row.
    '''
    parts = ['\n    <table class="table table-hover" style="margin:0 auto">\n    ']
    feature_keys = list(feature_dict.keys())
    if result:
        # Header row.
        parts.append('<tr><th>op_type</th>')
        for col_name in feature_keys:
            parts.append('<th>' + str(feature_dict[col_name]['col_name']) + '</th>')
        parts.append('</tr>')

        table_summary = dict()
        # Op types without feature data, kept in first-seen order: collecting
        # them in a set would make the row order vary between runs because of
        # str hash randomization.
        undefined_ops = []
        for op_type, result_op_dict in result.items():
            if op_type in op_ban_list:
                continue
            if not result_op_dict:
                undefined_ops.append(op_type)
                continue
            parts.append('<tr><td>' + str(op_type) + '</td>')
            for feature in feature_keys:
                spec = feature_dict[feature]
                parts.append('<td>' + bytes_to_human(result_op_dict.get(feature, '-'),
                                                     factor=spec['unit_int'],
                                                     unint_str=spec['unit_str']) + '</td>')
                # Accumulate the column total for the Summary row.
                if feature in table_summary:
                    table_summary[feature] += result_op_dict.get(feature, 0)
                else:
                    table_summary[feature] = result_op_dict.get(feature, 0)
            parts.append('</tr>')

        for op_type in undefined_ops:
            parts.append('<tr><td>' + str(op_type) + '</td>' + '<td>-</td>' * len(feature_keys) + '</tr>')

        # Summary row; columns no row contributed to are shown as '-'.
        parts.append('<tr><td>Summary</td>')
        for feature in feature_keys:
            spec = feature_dict[feature]
            parts.append('<td>' + bytes_to_human(table_summary.get(feature, '-'),
                                                 factor=spec['unit_int'],
                                                 unint_str=spec['unit_str']) + '</td>')
        parts.append('</tr>')
    parts.append('</table>')
    return ''.join(parts)


def drawtable_sequence(sequence, feature_dict):
    '''
    Build an HTML table listing the entries of *sequence* in order.

    :param sequence: list of entries. A dict (including dict subclasses) is
        rendered from its 'op_type' key and feature values; any other entry is
        rendered via str() with '-' in every feature column.
    :param feature_dict: ordered mapping of feature key -> column spec; each spec
        provides 'col_name' (header text) and 'unit_int' / 'unit_str' (passed to
        bytes_to_human as ``factor`` / ``unint_str``). Column order follows its keys.
    :return: the table HTML (str). For an empty *sequence* only the empty
        ``<table>`` element is returned, without a header row.
    '''
    parts = ['\n    <table class="table table-hover" style="margin:0 auto">\n    ']
    feature_keys = list(feature_dict.keys())
    if sequence:
        # Header row.
        parts.append('<tr><th>index</th><th>op_type</th>')
        for col_name in feature_keys:
            parts.append('<th>' + str(feature_dict[col_name]['col_name']) + '</th>')
        parts.append('</tr>')

        # Body rows; the displayed index is 1-based.
        for index, sequence_op in enumerate(sequence, 1):
            if isinstance(sequence_op, dict):
                parts.append('<tr><td>' + str(index) + '</td><td>' + str(sequence_op.get('op_type', '-')) + '</td>')
                for feature in feature_keys:
                    spec = feature_dict[feature]
                    parts.append('<td>' + bytes_to_human(sequence_op.get(feature, '-'),
                                                         factor=spec['unit_int'],
                                                         unint_str=spec['unit_str']) + '</td>')
            else:
                parts.append('<tr><td>' + str(index) + '</td><td>' + str(sequence_op) + '</td>')
                parts.append('<td>-</td>' * len(feature_keys))
            parts.append('</tr>')
    parts.append('</table>')
    return ''.join(parts)


def init_graph_def_placeholder_shape(graph_def, init_values):
    """
    Set the shape of Placeholder nodes and re-infer the output shapes of all ops.

    Example::

        new_graph_def = init_graph_def_placeholder_shape(old_graph_def, {
            "placeholder1": Tensor1.get_shape(),
            "placeholder2": Tensor2.get_shape(),
        })

    Args:
        graph_def: the original GraphDef. NOTE: it is modified in place. The
            '_output_shapes' attr is removed from every node, and the 'shape'
            attr of each Placeholder found in *init_values* is replaced.
        init_values: mapping of placeholder node name -> shape object exposing
            ``as_proto()`` (e.g. a TensorShape). Placeholders whose 'shape' attr
            has no dimensions must be present; others are optional.

    Returns:
        A new GraphDef, obtained by importing the modified *graph_def* into a
        fresh graph and exporting it with ``add_shapes=True``.

    Raises:
        AssertionError: a Placeholder whose 'shape' attr has no dimensions is
            missing from *init_values*. It is raised explicitly, so the check
            also runs under ``python -O``; the type is kept for existing callers.
    """
    for node in graph_def.node:
        # Stale shape annotations would conflict with the new placeholder shapes.
        if '_output_shapes' in node.attr:
            del node.attr['_output_shapes']
        if node.op != "Placeholder":
            continue
        if len(node.attr["shape"].shape.dim) < 1:
            if node.name not in init_values:
                raise AssertionError("%s <%s>(should be inited) not find in init_values%s" % (
                    node.op,
                    node.name,
                    [n for n in init_values]))
        elif node.name not in init_values:
            continue

        if "shape" in node.attr:
            del node.attr["shape"]
        node.attr["shape"].shape.CopyFrom(init_values[node.name].as_proto())

    tmp_graph = tf.Graph()
    # Importing and re-exporting only needs the graph as default; a Session is
    # unnecessary and creating one would initialise devices (e.g. reserve GPU memory).
    with tmp_graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return tmp_graph.as_graph_def(add_shapes=True)


def get_placehoder_prevalue(graph_def, placeholder_valuedict):
    """
    Give every Placeholder in *graph_def* a shape, then build a graph from it.

    If *placeholder_valuedict* is non-empty, its shapes are applied directly.
    Otherwise the current Placeholder shapes are shown in a PlaceholderDialog so
    the user can enter values (comma-separated integers per placeholder). The
    dialog is shown again whenever the entered values cannot be applied, until
    it succeeds or the user cancels. If the graph has no Placeholder at all, no
    dialog is shown.

    :param graph_def: GraphDef to initialise; it is passed to
        init_graph_def_placeholder_shape, which modifies it in place.
    :param placeholder_valuedict: placeholder name -> shape. Shapes entered in the
        dialog are written into this dict, so the caller sees them afterwards.
    :return: ``(graph, new_graph_def)`` on success, where *graph* is a new
        tf.Graph into which *new_graph_def* was imported; ``(False, None)`` if the
        user cancels or the shapes cannot be applied.
    """
    if len(placeholder_valuedict.keys()) == 0:
        # No preset values: collect the current shapes to prefill the dialog.
        placeholder_default_value = dict()
        for node in graph_def.node:
            if node.op == 'Placeholder':
                dims = node.attr['shape'].shape.dim
                # None means "no dimensions known"; otherwise one size per dim.
                placeholder_default_value[node.name] = [d.size for d in dims] if len(dims) >= 1 else None

        if not placeholder_default_value:
            # Nothing to ask the user about, so retrying cannot change the
            # outcome: a failure here is final instead of an endless loop.
            try:
                new_graph_def = init_graph_def_placeholder_shape(graph_def, placeholder_valuedict)
            except Exception:
                return False, None
        else:
            while True:
                dialog, ok = PlaceholderDialog.getPreValue(placeholder_default_value)
                if not ok:
                    return False, None
                try:
                    for key in dialog.keys():
                        val = [int(i) for i in dialog[key].split(',')]
                        # Build the shape directly instead of creating a
                        # throw-away tf.ones op in the default graph.
                        placeholder_valuedict[key] = tf.TensorShape(val)
                    new_graph_def = init_graph_def_placeholder_shape(graph_def, placeholder_valuedict)
                except Exception:
                    # Invalid input (e.g. non-integer dims) or shapes the graph
                    # rejects: ask the user again.
                    continue
                break
    else:
        try:
            new_graph_def = init_graph_def_placeholder_shape(graph_def, placeholder_valuedict)
        except Exception:
            return False, None
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(new_graph_def, name='')
    return graph, new_graph_def


def _analyse_graph_ops(graph, data, op_ban_list, feature_dict, emit=None, emit_value=None):
    """
    Run JsonOperation over every op of *graph* and collect the results.

    :param data: feature config, mapping op type -> config passed to JsonOperation
    :return: ``(result, sequence)``. *result* maps op type -> dict filled by
        ``JsonOperation.resolve`` for op types present in *data*, or None for op
        types missing from it. *sequence* is appended to in op order: missing op
        types are appended as their type string; other entries come from resolve.
    """
    result = dict()
    sequence = []
    operations = graph.get_operations()
    op_amount = len(operations)
    # NOTE(review): the Session (which also makes *graph* the default graph) is
    # kept in case JsonOperation relies on the default graph/session; drop it if
    # it does not, since creating a Session initialises devices.
    with tf.Session(graph=graph):
        for op_num, op in enumerate(operations, 1):
            if emit:
                # 1-based counter so the last op reports N/N.
                emit('analysis op in graph_def(%d/%d)' % (op_num, op_amount), emit_value + 2)
            if op.type in data:
                json_op = JsonOperation(op, data[op.type], op_ban_list, feature_dict)
                if op.type not in result:
                    result[op.type] = dict()
                json_op.resolve(result[op.type], sequence)
            else:
                result[op.type] = None
                sequence.append(op.type)
    return result, sequence


def _build_feature_tabs(info_table, sequence_table):
    """Return the "Feature" section: two Bootstrap tabs holding the two tables."""
    return '''
<style>
.nav-tabs > li{
    display:inline-block;
    float:none;
}
.nav-tabs{
    text-align:center;
}
.table{
    table-layout:fixed;
}
</style>
<div class="page-header">
<h2>Feature</h2>
</div>
<ul class="nav nav-tabs" id="table_switch">
  <li class="active"><a href="#first">statistical information</a></li>
  <li><a href="#second">op sequence information</a></li>
</ul>
<div class="tab-content">
<div class="tab-pane active" id="first">

    '''+info_table+'''
</div>
<div class="tab-pane" id="second">

    '''+sequence_table+'''
</div>
</div>
<script>
$(document).ready(function() {
  $('#table_switch a').click(function (e) {
    e.preventDefault();
    $(this).tab('show');
  })
})
</script>

    '''


def _build_reload_form(placeholder_valuedict):
    """
    Return the "Reload" section: one row of inputs per placeholder plus the button.

    Clicking the button sends ``{"placeholder_datalist": {name: [dim, ...]}}``
    (dims as strings) as JSON to ``qt_mainwindow.test_js_call_qt`` over
    QWebChannel, or the string "False" if any input is empty.
    """
    import html

    parts = ['''
<div class="page-header">
<h2>
Reload
<small>you can reset the placeholder value</small>
</h2>
</div>
<style>
#reload_table td{
    padding:15px;
}
</style>
<table border="0" style="margin:0 auto" id="reload_table">

    ''']
    for key in placeholder_valuedict.keys():
        # Rows carry a class so the script can tell them from the button row
        # (this also keeps rank-0 placeholders, which have no input cells).
        parts.append('<tr class="placeholder_row"><td>' + html.escape(str(key)) + '</td>')
        for j in placeholder_valuedict[key]:
            parts.append('<td><input type="text" size="5" value="' + html.escape(str(j)) + '"/></td>')
        parts.append('</tr>')
    parts.append('<tr><td><button id="reload_btn">reload</button></td></tr></table>\n')
    parts.append('''
<script src="qrc:///qtwebchannel/qwebchannel.js"></script>
<script>
$(document).ready(function(){
    new QWebChannel(qt.webChannelTransport, function(channel){
        window.qt_mainwindow = channel.objects.qt_mainwindow;
    })
    $("#reload_btn").click(function(){
        var flag = true;
        var datalist = {};
        $("#reload_table").find("tr.placeholder_row").each(function(index, row){
            var row_arr = [];
            $(row).find("input").each(function(i, value){
                var text = $.trim($(value).val());
                if(text == ''){
                    flag = false;
                }else{
                    row_arr.push(text);
                }
            });
            datalist[$(row).children("td:eq(0)").text()] = row_arr;
        });
        if(flag==false){
            window.qt_mainwindow.test_js_call_qt("False");
        }else{
            window.qt_mainwindow.test_js_call_qt(JSON.stringify({"placeholder_datalist": datalist}));
        }
    })
})
</script>
    ''')
    return ''.join(parts)


def make_tensor_html_file(graph, graph_def, resource_dir, placeholder_valuedict=None, pre_html_data="", post_html_data="", emit=None, emit_value=None):
    """
    Build the model analysis page and write it to ``<resource_dir>/index_tmp_page.html``.

    The page is meta.html + a UTF-8 charset meta tag + base_template.html (both
    read from *resource_dir*), followed by *pre_html_data*, the TensorBoard graph
    view, the placeholder reload form, the feature tables and *post_html_data*.

    :param graph: tf.Graph whose operations are analysed
    :param graph_def: GraphDef shown in the graph view (the first '_' of each
        node name is turned into '/' so related nodes are grouped)
    :param resource_dir: directory holding the templates; the page is written there
    :param placeholder_valuedict: placeholder name -> iterable of dims, shown in
        the reload form (default: no placeholders)
    :param pre_html_data: HTML inserted before the generated content
    :param post_html_data: HTML inserted after the generated content
    :param emit: optional progress callback, called as ``emit(message, value)``
    :param emit_value: progress value on entry (required when *emit* is given)
    :return: ``(path of the written page, sequence)`` where *sequence* is the
        per-op list rendered in the "op sequence information" tab
    """
    if placeholder_valuedict is None:
        placeholder_valuedict = dict()
    doc_dir = os.path.abspath(resource_dir)
    if emit:
        emit('parse graph_def', emit_value + 1)

    data = get_config_json()
    op_ban_list = get_config_view_hidden_ops()
    feature_dict = get_config_view_show_features()
    result, sequence = _analyse_graph_ops(graph, data, op_ban_list, feature_dict, emit, emit_value)

    if emit:
        emit('arrange the features to table', emit_value + 3)
    html_data = parse_graph_to_html(rename_nodes(graph_def, lambda s: "/".join(s.split('_', 1))),
                                    base_url="file://%s/" % doc_dir,
                                    raw=True)
    info_table = drawtable_optype(result, feature_dict, op_ban_list)
    sequence_table = drawtable_sequence(sequence, feature_dict)
    bootstrap_show = _build_feature_tabs(info_table, sequence_table)
    reload_btn = _build_reload_form(placeholder_valuedict)

    if emit:
        emit('save to index_tmp_page.html', emit_value + 4)
    # Templates are opened read-only ("r+" needed write access just to read them).
    with open(os.path.join(doc_dir, "meta.html"), "r", encoding="utf-8") as f:
        license_html = f.read()
    with open(os.path.join(doc_dir, "base_template.html"), "r", encoding="utf-8") as fh:
        page_content = fh.read()
    # Write UTF-8 explicitly and declare the same charset. Previously the locale
    # encoding was used and GBK was assumed on every Windows machine, which
    # mismatches (or fails to encode) on non-Chinese locales.
    meta_str = '<meta charset="UTF-8">'

    tmp_index_page = os.path.join(doc_dir, "index_tmp_page.html")
    with open(tmp_index_page, "w", encoding="utf-8") as fh:
        fh.write(license_html + meta_str + page_content + pre_html_data + html_data
                 + reload_btn + bootstrap_show + post_html_data + "</body></html>")
    return tmp_index_page, sequence
